오늘 풀어본 문제는 백준의 2042번 문제1이다. 문제 풀이에 사용한 언어는 C++ 이다.
이 문제의 내용과 조건은 다음과 같다.
어떤 $N$ 개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 $1,2,3,4,5$ 라는 수가 있고, $3$ 번째 수를 $6$ 으로 바꾸고 $2$ 번째부터 $5$ 번째까지 합을 구하라고 한다면 $17$ 을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 $2$ 로 바꾸고 $3$ 번째부터 $5$ 번째까지 합을 구하라고 한다면 $12$ 가 될 것이다.
첫째 줄에 수의 개수 $N$ $(1 \leq N \leq 1,000,000)$ 과 $M$ $(1 \leq M \leq 10,000)$, $K$ $(1 \leq K \leq 10,000)$ 가 주어진다. $M$ 은 수의 변경이 일어나는 횟수이고, $K$ 는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 $N+1$ 번째 줄까지 $N$ 개의 수가 주어진다. 그리고 $N+2$ 번째 줄부터 $N+M+K+1$ 번째 줄까지 세 개의 정수 $a, b, c$ 가 주어지는데, $a$ 가 $1$ 인 경우 $b$ $(1 \leq b \leq N)$ 번째 수를 $c$ 로 바꾸고 $a$ 가 $2$ 인 경우에는 $b$ $(1 \leq b \leq N)$ 번째 수부터 $c$ $(b \leq c \leq N)$ 번째 수까지의 합을 구하여 출력하면 된다.
입력으로 주어지는 모든 수는 $-2^{63}$ 보다 크거나 같고, $2^{63}-1$ 보다 작거나 같은 정수이다.
첫째 줄부터 $K$ 줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 $-2^{63}$ 보다 크거나 같고, $2^{63}-1$ 보다 작거나 같은 정수이다.
이 문제는 세그먼트 트리를 이용해야 하는 문제이다. 예전에 금광 세그먼트 트리를 활용해야 하는 문제를 풀었었는데, 그 코드를 가져와 수정하여 SegmentTree
클래스를 만들었고 이 자료구조를 활용하기로 하였다. (그냥 세그먼트 트리가 아니라 금광 세그먼트 트리 문제를 먼저 풀은건 무시하고 넘어가자.)
코드는 다음과 같이 작성하였다.
#include <bits/stdc++.h>
using namespace std;
template <typename T>
class SegmentTree {
private:
std::vector<T> tree;
std::vector<T> arr;
size_t n;
void build(size_t idx, size_t left, size_t right) {
if (left == right) {
tree[idx] = arr[left];
return;
}
size_t mid = (left + right) / 2;
build(2 * idx, left, mid);
build(2 * idx + 1, mid + 1, right);
const T& leftValue = tree[2 * idx];
const T& rightValue = tree[2 * idx + 1];
tree[idx] = leftValue + rightValue;
}
void update(size_t index, const T& value, size_t node, size_t currentLeft, size_t currentRight) {
if (currentLeft <= index && index <= currentRight) {
T difference = value - arr[index];
tree[node] += difference;
if (currentLeft != currentRight) {
size_t mid = (currentLeft + currentRight) / 2;
update(index, value, 2 * node, currentLeft, mid);
update(index, value, 2 * node + 1, mid + 1, currentRight);
}
}
}
const T query(size_t i, size_t j, size_t node, size_t currentLeft, size_t currentRight) {
if (j < currentLeft || currentRight < i) {
return 0;
}
else if (i <= currentLeft && currentRight <= j) {
return tree[node];
}
else {
size_t mid = (currentLeft + currentRight) / 2;
const T leftValue = query(i, j, 2 * node, currentLeft, mid);
const T rightValue = query(i, j, 2 * node + 1, mid + 1, currentRight);
return leftValue + rightValue;
}
}
public:
SegmentTree(const std::vector<T>& a) : arr(a), n(a.size()) {
tree.resize(4 * n + 1);
build(1, 0, n - 1);
}
void update(size_t index, const T& value) {
update(index, value, 1, 0, n - 1);
arr[index] = value;
}
const T query(size_t i, size_t j) {
if (i == j) {
return arr[i];
}
else {
return query(i, j, 1, 0, n - 1);
}
}
};
int main(void) {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int N, M, K;
cin >> N >> M >> K;
vector<long long int> arr(N);
for (int i = 0; i < N; i++) {
cin >> arr[i];
}
SegmentTree<long long int> segTree(arr);
for (int i = 0; i < M + K; i++) {
int a, b, c;
cin >> a >> b >> c;
if (a == 1) {
segTree.update(b - 1, c);
}
else if (a == 2) {
cout << segTree.query(b - 1, c - 1) << "\n";
}
}
return 0;
}
실행 결과 ‘틀렸습니다’가 나왔다.
처음에는 내가 만든 SegmentTree
클래스의 함수에 문제점이 있나 하고 한참을 이리저리 듈려봤지만, 결국 알아낸 원인은 a
, b
, c
변수가 int
형이었기 때문에 입력 과정에서 문제가 있던 것이었다. 그래서 long long int
로 바꿔주었다.
코드는 다음과 같이 작성하였다.
#include <bits/stdc++.h>
using namespace std;
template <typename T>
class SegmentTree {
private:
std::vector<T> tree;
std::vector<T> arr;
size_t n;
void build(size_t idx, size_t left, size_t right) {
if (left == right) {
tree[idx] = arr[left];
return;
}
size_t mid = (left + right) / 2;
build(2 * idx, left, mid);
build(2 * idx + 1, mid + 1, right);
const T& leftValue = tree[2 * idx];
const T& rightValue = tree[2 * idx + 1];
tree[idx] = leftValue + rightValue;
}
void update(size_t index, const T& value, size_t node, size_t currentLeft, size_t currentRight) {
if (currentLeft <= index && index <= currentRight) {
T difference = value - arr[index];
tree[node] += difference;
if (currentLeft != currentRight) {
size_t mid = (currentLeft + currentRight) / 2;
update(index, value, 2 * node, currentLeft, mid);
update(index, value, 2 * node + 1, mid + 1, currentRight);
}
}
}
const T query(size_t i, size_t j, size_t node, size_t currentLeft, size_t currentRight) {
if (j < currentLeft || currentRight < i) {
return 0;
}
else if (i <= currentLeft && currentRight <= j) {
return tree[node];
}
else {
size_t mid = (currentLeft + currentRight) / 2;
const T leftValue = query(i, j, 2 * node, currentLeft, mid);
const T rightValue = query(i, j, 2 * node + 1, mid + 1, currentRight);
return leftValue + rightValue;
}
}
public:
SegmentTree(const std::vector<T>& a) : arr(a), n(a.size()) {
tree.resize(4 * n + 1);
build(1, 0, n - 1);
}
void update(size_t index, const T& value) {
update(index, value, 1, 0, n - 1);
arr[index] = value;
}
const T query(size_t i, size_t j) {
return query(i, j, 1, 0, n - 1);
}
};
int main(void) {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int N, M, K;
cin >> N >> M >> K;
vector<long long int> arr(N);
for (int i = 0; i < N; i++) {
cin >> arr[i];
}
SegmentTree<long long int> segTree(arr);
for (int i = 0; i < M + K; i++) {
long long int a, b, c;
cin >> a >> b >> c;
if (a == 1) {
segTree.update(b - 1, c);
}
else if (a == 2) {
cout << segTree.query(b - 1, c - 1) << "\n";
}
}
return 0;
}
그러자 모든 테스트 케이스를 통과하고 정답이 나오는 것을 확인할 수 있었다.
예전에 풀었던 금광 세그먼트 트리를 이용하는 문제는 우선 입력을 받은 뒤 트리를 구성하면 구간 내의 값들만 알아내면 되었기 때문에 update
함수는 없어 추가로 구현해야 했다. 이 과정에서 구간에 해당하는 노드들을 모두 업데이트 해주어야 했기 때문에 혹시라도 시간 초과가 발생할까 우려스러웠지만, 다행히도 그런 일은 일어나지 않았다.
그러나 나중에 Lazy Propagation을 이용해야 하는 문제들이 있다고 들었다. 아마 그런 문제들에서 이 자료구조를 그대로 사용했다가는 시간 초과가 날 것이다. 그리고 머지 않아 그런 문제들을 풀어야할 날이 올 것 같다.
오늘 구현한 세그먼트 트리 자료구조는 마찬가지로 레포지토리2에 저장해두었다. 필요한 사람은 링크를 타고 들어가 확인해보면 된다.
오늘의 PS는 여기까지!
1: https://www.acmicpc.net/problem/2042
2: https://github.com/ChoiCube84/baekjoon-solutions/blob/main/C%2B%2B/2042/segment_tree.hpp